from PIL import Image
import os
import numpy as np

def key(s):
    return int(s.split('.')[0][1:])

#将图像拉平合并到一个矩阵中，并存入npz文件方便读取
if __name__ == '__main__':
    root = './YALE'
    files = os.listdir(root)
    files.sort(key=key)
    data_mat = [[] for _ in range(15)]
    
    for file in files:
        #读取图像并加入矩阵
        file_path = os.path.join(root, file)
        img = Image.open(file_path)
        npimg = np.array(img)
        npimg = npimg.flatten()
        #根据文件名生成对应标签
        ID = int(file.split('.')[0][1:])
        lable = int(np.ceil(ID / 11) - 1)
        data_mat[lable].append(npimg)

    #存入npz文件
    data_mat = np.array(data_mat, dtype=np.float32)
    np.savez('data.npz', data=data_mat)